package com.lw.leetcode.dp.c;

import java.util.HashSet;
import java.util.Set;
import java.util.function.BinaryOperator;

/**
 * LeetCode 2019. The Score of Students Solving Math Expression.
 *
 * <p>The expression consists of single-digit operands separated by {@code '+'} and
 * {@code '*'}. An answer equal to the value obtained with normal precedence
 * (multiplication before addition) earns 5 points; an answer equal to any value
 * obtainable by evaluating the operators in some other order earns 2 points; every
 * other answer earns 0 points.
 *
 * @author liw
 * @version 1.0
 * @date 2022/8/17 14:28
 */
public class ScoreOfStudents {

    /** Intermediate results greater than this value are discarded by the search. */
    private static final int MAX_ANSWER = 1000;

    /** Points for an answer equal to the correctly evaluated expression. */
    private static final int FULL_SCORE = 5;

    /** Points for an answer reachable through some other evaluation order. */
    private static final int PARTIAL_SCORE = 2;

    /**
     * Small manual demo; prints the score for one sample input.
     *
     * <p>Other samples (expression, answers -> expected output):
     * <ul>
     *   <li>{@code "7+3*1*2"}, {20, 13, 42} -> 7</li>
     *   <li>{@code "6+0*1"}, {12, 9, 6, 4, 8, 6} -> 10</li>
     *   <li>{@code "6+2*1+5*4+3*2+9*8+2*2+3*4+1*2+1"}, {12, 9, 6, 4, 8, 6} -> 0</li>
     *   <li>{@code "4+8*8+4*8+4*8+8*4+4*4+4*4+8*4+8"}, {12, 9, 6, 4, 8, 6} -> 0</li>
     * </ul>
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ScoreOfStudents test = new ScoreOfStudents();

        // expected output: 19
        String str = "3+5*2";
        int[] arr = {13, 0, 10, 13, 13, 16, 16};

        int i = test.scoreOfStudents(str, arr);
        System.out.println(i);
    }

    /**
     * Grades every submitted answer against the expression and returns the total score.
     *
     * <p>The expression is expected to be well formed: characters at even positions are
     * digits and characters at odd positions are operators. Any operator character other
     * than {@code '+'} is treated as multiplication.
     *
     * @param s       the expression, e.g. {@code "3+5*2"}
     * @param answers the submitted answers; values of any magnitude or sign are accepted
     * @return the sum of the points awarded to all answers
     */
    public int scoreOfStudents(String s, int[] answers) {
        int correct = getSum(s);

        int length = s.length();
        int opCount = length >> 1;
        int[] nums = new int[opCount + 1];
        boolean[] isMul = new boolean[opCount];
        nums[0] = s.charAt(0) - '0';
        for (int i = 1; i < length; i += 2) {
            int idx = i >> 1;
            isMul[idx] = s.charAt(i) != '+';
            nums[idx + 1] = s.charAt(i + 1) - '0';
        }

        Set<Integer> reachable = possibleResults(nums, isMul);

        int score = 0;
        for (int answer : answers) {
            if (answer == correct) {
                score += FULL_SCORE;
            } else if (reachable.contains(answer)) {
                // A set lookup (instead of an array indexed by the answer) keeps
                // negative or very large answers from throwing.
                score += PARTIAL_SCORE;
            }
        }
        return score;
    }

    /**
     * Collects the values obtainable by applying the operators in any order.
     *
     * <p>Interval DP over operand ranges: {@code results[lo][hi]} holds the values of the
     * sub-expression spanning operands {@code lo..hi}. Each range is built by choosing the
     * operator {@code k} that is applied last and combining every value of the left part
     * with every value of the right part. A single operand is its own range, so an
     * expression without operators is handled without a special case.
     *
     * <p>Combined values above {@link #MAX_ANSWER} are dropped, which bounds each set to at
     * most {@code MAX_ANSWER + 1} entries for non-negative operands.
     * NOTE(review): the pruning mirrors the original implementation and presumably relies
     * on answers never exceeding 1000 - confirm before reusing with other limits.
     *
     * @param nums  the operands in order of appearance
     * @param isMul {@code isMul[k]} is true when the operator between operand {@code k} and
     *              operand {@code k + 1} is a multiplication, false for addition
     * @return the values obtainable for the whole expression
     */
    private static Set<Integer> possibleResults(int[] nums, boolean[] isMul) {
        int n = nums.length;
        // Generic arrays cannot be created directly; the table is only ever filled with
        // Set<Integer> instances below, so the unchecked conversion is safe.
        @SuppressWarnings("unchecked")
        Set<Integer>[][] results = new Set[n][n];
        for (int i = 0; i < n; i++) {
            Set<Integer> single = new HashSet<>();
            single.add(nums[i]);
            results[i][i] = single;
        }
        for (int span = 1; span < n; span++) {
            for (int lo = 0; lo + span < n; lo++) {
                int hi = lo + span;
                Set<Integer> values = new HashSet<>();
                // k is the operator applied last; it sits between operands k and k + 1.
                for (int k = lo; k < hi; k++) {
                    boolean multiply = isMul[k];
                    for (int left : results[lo][k]) {
                        for (int right : results[k + 1][hi]) {
                            int combined = multiply ? left * right : left + right;
                            if (combined <= MAX_ANSWER) {
                                values.add(combined);
                            }
                        }
                    }
                }
                results[lo][hi] = values;
            }
        }
        return results[0][n - 1];
    }

    /**
     * Evaluates the expression with normal precedence (multiplication before addition).
     *
     * <p>Keeps a running total plus the current product term, so the expression length is
     * not limited by any fixed-size buffer. Any operator character other than {@code '+'}
     * is treated as multiplication.
     *
     * @param s the expression; an empty string evaluates to 0
     * @return the value of the expression
     */
    private int getSum(String s) {
        int length = s.length();
        if (length == 0) {
            return 0;
        }
        int total = 0;
        int term = s.charAt(0) - '0';
        for (int i = 1; i < length; i += 2) {
            int operand = s.charAt(i + 1) - '0';
            if (s.charAt(i) == '+') {
                // The previous product term is complete; bank it and start a new one.
                total += term;
                term = operand;
            } else {
                term *= operand;
            }
        }
        return total + term;
    }

}
